from torch.utils.data import Dataset
class SizeMismatchError(ValueError, AssertionError):
    """Raised when ``data`` and ``label`` passed to TensorDataset differ in length.

    Also subclasses ``AssertionError`` so callers that caught the error from
    the former ``assert``-based check keep working unchanged.
    """


class TensorDataset(Dataset):
    """Map-style dataset pairing each sample with its label.

    Args:
        data: Sample container; ``data[idx]`` is one sample. Anything that
            supports ``len()`` and ``[]`` works (e.g. a ``torch.Tensor`` whose
            first dimension indexes samples, a NumPy array, or a list).
        label: Labels aligned with ``data``; ``label[idx]`` is the label of
            ``data[idx]``.

    Raises:
        SizeMismatchError: if ``len(data) != len(label)``.
    """

    def __init__(self, data, label):
        # An explicit raise (instead of ``assert``) keeps this check active
        # when Python runs with ``-O``. ``len()`` equals ``size(0)`` for
        # tensors with at least one dimension, and also covers arrays/lists.
        n_data, n_label = len(data), len(label)
        if n_data != n_label:
            raise SizeMismatchError(
                f"data and label must have the same length, "
                f"got {n_data} and {n_label}"
            )
        self.data = data
        self.label = label

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair at position ``idx``."""
        return self.data[idx], self.label[idx]

    def __len__(self):
        """Return the number of samples (equal to the number of labels)."""
        return len(self.label)